#' Exact tests to detect mutually exclusive, co-occurring and altered genesets.
#'
#' @description Performs pair-wise Fisher's exact tests to detect mutually exclusive or co-occurring events.
#' @details This function and plotting is inspired from genetic interaction analysis performed in the published study combining gene expression and mutation data in MDS. See reference for details.
#'
#' Every selected gene is tested against every other selected gene with \code{fisher.test} on the
#' per-sample mutated/wild-type status. Samples in the MAF sample summary that are absent from the
#' oncomatrix of the selected genes are included as wild-type for all of them.
#' The matrix plot is drawn only when five or more genes are tested.
#' @references Gerstung M, Pellagatti A, Malcovati L, et al. Combining gene mutation with gene expression data improves outcome prediction in myelodysplastic syndromes. Nature Communications. 2015;6:5901. doi:10.1038/ncomms6901.
#' @param maf an \code{\link{MAF}} object generated by \code{\link{read.maf}}
#' @param top check for interactions among top 'n' number of genes. Defaults to top 25. Ignored when \code{genes} is provided.
#' @param genes List of genes among which interactions should be tested. If not provided, test will be performed between the top \code{top} genes.
#' @param geneOrder Plot the results in given order. Must contain exactly the genes used for analysis. Default NULL.
#' @param pvalue Two p-value thresholds. Pairs with p below \code{min(pvalue)} are marked with \code{pvSymbols[2]}, pairs below \code{max(pvalue)} with \code{pvSymbols[1]}. Default c(0.05, 0.01)
#' @param returnAll If TRUE returns test statistics for all pairs of tested genes. If FALSE, returns only pairs with p-value below \code{min(pvalue)}. Default TRUE.
#' @param fontSize cex for gene names. Default 0.8
#' @param leftMar Left margin. Default 4
#' @param topMar Top margin. Default 4
#' @param showSigSymbols Default TRUE. Highlight significant pairs
#' @param showCounts Default FALSE. Include number of events in the plot
#' @param countStats Default `all`. `all` annotates every tested cell, `sig` only cells with p below \code{max(pvalue)}
#' @param countType Default `all`. `all` shows the event ratio (samples altered in both genes / samples altered in exactly one), `cooccur` the number of samples altered in both, `mutexcl` the number of samples altered in exactly one
#' @param countsFontSize Default 0.8
#' @param countsFontColor Default `black`
#' @param colPal RColorBrewer palette name. Default `BrBG`. See RColorBrewer::display.brewer.all() for details
#' @param revPal Reverse the color palette. Default FALSE
#' @param showSum Append the number of altered samples (in square brackets) to gene names in the plot. Default TRUE
#' @param plotPadj Plot FDR-adjusted p-values instead of raw p-values. Default FALSE
#' @param colNC Number of colors taken from \code{colPal} before interpolation, minimum 3, default 9
#' @param nShiftSymbols The significance legend is drawn at x = (number of genes - nShiftSymbols). Default 5
#' @param sigSymbolsSize size of symbols in the matrix and in legend. Default 2
#' @param sigSymbolsFontSize size of font in legends. Default 0.9
#' @param pvSymbols vector of pch numbers for symbols of p-value for upper and lower thresholds c(upper, lower). Default c(46, 42)
#' @param limitColorBreaks If TRUE, signed -log10(p) values beyond +/-3 are drawn with the most extreme color. If FALSE they fall outside the color scale and are not filled. Default TRUE
#' @examples
#' laml.maf <- system.file("extdata", "tcga_laml.maf.gz", package = "maftools")
#' laml <- read.maf(maf = laml.maf)
#' somaticInteractions(maf = laml, top = 5)
#' @return A \code{data.table} with one row per unique gene pair (self-pairs removed) and columns
#'   gene1, gene2, pValue, oddsRatio, the sample counts `00`, `01`, `11`, `10` (first digit = gene1 status),
#'   pAdj (FDR), Event (Co_Occurence or Mutually_Exclusive), pair and event_ratio.
#' @export

somaticInteractions = function(maf, top = 25, genes = NULL, pvalue = c(0.05, 0.01), returnAll = TRUE,
                               geneOrder = NULL, fontSize = 0.8, leftMar = 4, topMar = 4, showSigSymbols = TRUE,
                               showCounts = FALSE, countStats = 'all', countType = 'all',
                               countsFontSize = 0.8, countsFontColor = "black", colPal = "BrBG", revPal = FALSE, showSum = TRUE, plotPadj = FALSE, colNC=9, nShiftSymbols = 5, sigSymbolsSize=2,sigSymbolsFontSize=0.9, pvSymbols = c(46,42), limitColorBreaks = TRUE){

  if(is.null(genes)){
    geneSummary = getGeneSummary(x = maf)
    #Cap at the number of available genes; indexing past the end would yield NA gene names
    genes = geneSummary$Hugo_Symbol[seq_len(min(top, nrow(geneSummary)))]
  }

  if(length(genes) < 2){
    stop("Minimum two genes required!")
  }

  om = createOncoMatrix(m = maf, g = genes)
  all.tsbs = levels(getSampleSummary(x = maf)[,Tumor_Sample_Barcode])

  #Samples x genes alteration matrix
  mutMat = t(om$numericMatrix)
  #Pairwise tests need at least two gene columns; a single-row matrix is rejected as well
  if(nrow(mutMat) < 2 || ncol(mutMat) < 2){
    stop("Minimum two genes required!")
  }
  mutMat[mutMat > 0] = 1

  #Samples known to the MAF but absent from the oncomatrix are added back as all-zero (wild-type) rows
  missing.tsbs = all.tsbs[!all.tsbs %in% rownames(mutMat)]
  if(length(missing.tsbs) > 0){
    wildType = matrix(data = 0, nrow = length(missing.tsbs), ncol = ncol(mutMat),
                      dimnames = list(missing.tsbs, colnames(mutMat)))
    mutMat = rbind(mutMat, wildType)
  }

  #Pairwise fisher test source code borrowed from: https://www.nature.com/articles/ncomms6901
  fisherRes = .somaticInteractionsFisher(mutMat = mutMat)
  interactions = fisherRes$interactions

  #Any cell with p < 1 counts as a meaningful interaction; NA cells (failed tests) are dropped by which()
  pMat = 10^-abs(interactions)
  sigPairs = which(x = pMat < 1, arr.ind = TRUE)
  if(nrow(sigPairs) < 1){
    stop("No meaningful interactions found.")
  }
  #Cells with p < 1 first, then p == 1 cells; with the stable order() below this decides
  #which orientation of a tied pair survives de-duplication
  sigPairs = rbind(sigPairs, which(x = pMat >= 1, arr.ind = TRUE))

  sigPairsTbl = .somaticInteractionsPairTable(mutMat = mutMat, logPMat = interactions,
                                              oddsMat = fisherRes$oddsRatio, cells = sigPairs)
  #One row per unordered pair: the orientation with the smallest p-value is kept
  sigPairsTblSig = sigPairsTbl[order(as.numeric(pValue))][!duplicated(pair)]

  if(plotPadj){
    #Rebuild the plotted matrix from FDR values. Both orientations of each pair are filled because the
    #de-duplicated table holds only one of them, which may lie in the lower triangle masked before plotting
    padjLog = ifelse(sigPairsTblSig$oddsRatio > 1, yes = -log10(sigPairsTblSig$pAdj), no = log10(sigPairsTblSig$pAdj))
    rowIdx = match(sigPairsTblSig$gene1, rownames(interactions))
    colIdx = match(sigPairsTblSig$gene2, colnames(interactions))
    interactions[] = NA_real_
    interactions[cbind(rowIdx, colIdx)] = padjLog
    interactions[cbind(colIdx, rowIdx)] = padjLog
  }

  sigPairsTblSig = sigPairsTblSig[gene1 != gene2] #Remove diagonal elements

  #Source code borrowed from: https://www.nature.com/articles/ncomms6901
  if(nrow(interactions) >= 5){
    #Restore the caller's graphics state (mar, las, fig, new, ...) when the function exits
    oldPar = par(no.readonly = TRUE)
    on.exit(par(oldPar), add = TRUE)

    m = nrow(interactions)
    n = ncol(interactions)

    col_pal = RColorBrewer::brewer.pal(colNC, colPal)
    if(revPal){
      col_pal = rev(col_pal)
    }
    col_pal = grDevices::colorRampPalette(colors = col_pal)(m*n-1)

    if(!is.null(geneOrder)){
      if(!all(rownames(interactions) %in% geneOrder) || !all(geneOrder %in% rownames(interactions))){
        stop("Genes in geneOrder does not match the genes used for analysis.")
      }
      interactions = interactions[geneOrder, geneOrder]
    }

    #Only the upper triangle is drawn; the diagonal (self-pairs) is masked too
    interactions[lower.tri(x = interactions, diag = TRUE)] = NA

    gene_sum = getGeneSummary(x = maf)[Hugo_Symbol %in% rownames(interactions), .(Hugo_Symbol, AlteredSamples)]
    data.table::setDF(gene_sum, rownames = as.character(gene_sum$Hugo_Symbol))
    gene_sum = gene_sum[rownames(interactions),]
    if(!all(rownames(gene_sum) == rownames(interactions))){
      stop(paste0("Row mismatches!"))
    }
    if(!all(rownames(gene_sum) == colnames(interactions))){
      stop(paste0("Column mismatches!"))
    }
    if(showSum){
      geneLabels = paste0(gene_sum$Hugo_Symbol, " [", gene_sum$AlteredSamples, "]")
    }else{
      geneLabels = rownames(gene_sum)
    }

    par(bty="n", mar = c(1, leftMar, topMar, 2)+.1, las=2, fig = c(0, 1, 0, 1))

    #Color scale spans signed -log10(p) in [-3, 3]; one more break than colors
    maxLog10p = 3
    colBreaks = seq(-maxLog10p, maxLog10p, length.out = length(col_pal) + 1)
    interactions4plot = interactions
    if(limitColorBreaks){
      #Clamp extreme values so they get the extreme colors instead of none. Only the plotted copy is
      #clamped; counts and significance symbols below use the unclamped values
      interactions4plot[interactions4plot < (-maxLog10p)] = -maxLog10p
      interactions4plot[interactions4plot > maxLog10p] = maxLog10p
    }

    image(x=1:n, y=1:m, interactions4plot, col = col_pal,
          xaxt="n", yaxt="n",
          xlab="",ylab="", xlim=c(0, n+1), ylim=c(0, n+1),
          breaks = colBreaks)

    abline(h=0:n+.5, col="white", lwd=.5)
    abline(v=0:n+.5, col="white", lwd=.5)

    mtext(side = 2, at = 1:m, text = geneLabels, cex = fontSize, font = 3)
    mtext(side = 3, at = 1:n, text = geneLabels, cex = fontSize, font = 3)

    #P-value (or FDR, when plotPadj) per plotted cell; NA outside the upper triangle
    pInteractions = 10^-abs(interactions)

    if(showCounts){
      countStats = match.arg(arg = countStats, choices = c("all", "sig"))
      countType = match.arg(arg = countType, choices = c("all", "cooccur", "mutexcl"))

      if(countStats == 'sig'){
        countCells = arrayInd(which(pInteractions < max(pvalue)), dim(pInteractions))
        countsTbl = sigPairsTblSig[pValue < max(pvalue)]
        if(countType == 'cooccur'){
          countsTbl = countsTbl[Event %in% "Co_Occurence"]
        }else if(countType == 'mutexcl'){
          countsTbl = countsTbl[Event %in% "Mutually_Exclusive"]
        }
      }else{
        countCells = rbind(arrayInd(which(pInteractions < max(pvalue)), dim(pInteractions)),
                           arrayInd(which(pInteractions >= max(pvalue)), dim(pInteractions)))
        countsTbl = sigPairsTblSig
      }

      #seq_len() so that no qualifying cell simply draws nothing
      for(k in seq_len(nrow(countCells))){
        g12 = paste(sort(c(rownames(interactions)[countCells[k, 1]], colnames(interactions)[countCells[k, 2]])), collapse = ', ')
        hit = countsTbl[pair %in% g12]
        e = switch(countType,
                   all = hit$event_ratio,
                   cooccur = hit$`11`,
                   mutexcl = hit$`01` + hit$`10`)
        if(length(e) == 0){
          e = 0
        }
        text(countCells[k, 1], countCells[k, 2], labels = e, font = 3, col = countsFontColor, cex = countsFontSize)
      }
    }

    if(showSigSymbols){
      w = arrayInd(which(pInteractions < min(pvalue)), dim(pInteractions))
      points(w, pch=pvSymbols[2], col="black", cex = sigSymbolsSize)
      #>= so that p exactly at the lower threshold still gets the upper-threshold symbol
      w = arrayInd(which((pInteractions < max(pvalue)) & (pInteractions >= min(pvalue))), dim(pInteractions))
      points(w, pch=pvSymbols[1], col="black", cex = sigSymbolsSize)

      sigLabel = if(plotPadj) " fdr < " else " P < "
      points(x = n-nShiftSymbols, y = 0.7*n, pch = pvSymbols[2], cex = sigSymbolsSize)
      text(x = n-nShiftSymbols, y = 0.7*n, paste0(sigLabel, min(pvalue)), pos=4, cex = sigSymbolsFontSize, adj = 0)
      points(x = n-nShiftSymbols, y = 0.65*n, pch = pvSymbols[1], cex = sigSymbolsSize)
      text(x = n-nShiftSymbols, y = 0.65*n, paste0(sigLabel, max(pvalue)), pos=4, cex = sigSymbolsFontSize)
    }

    #Color key drawn in an inset figure region
    par(fig = c(0.4, 0.7, 0, 0.4), new = TRUE)
    image(
      x = c(0.8, 1),
      y = seq(0, 1, length.out = 200),
      z = matrix(seq(0,1,length.out = 200), nrow = 1),
      col = col_pal, xlim = c(0, 1), ylim = c(0, 1), axes = FALSE, xlab = NA, ylab = NA
    )

    atLims = seq(0, 1, length.out = 7)
    axis(side = 4, at = atLims,  tcl=-.15, labels =c("> 3 (Mutually exclusive)", 2, 1, 0, 1, 2, ">3 (Co-occurence)"), lwd=.5, cex.axis = sigSymbolsFontSize, line = 0.2)
    text(x = 0.4, y = 0.5, labels = if(plotPadj) "-log10(fdr)" else "-log10(P-value)", srt = 90, cex = sigSymbolsFontSize, xpd = TRUE)
  }

  if(!returnAll){
    sigPairsTblSig = sigPairsTblSig[pValue < min(pvalue)]
  }
  sigPairsTblSig
}

# Pairwise Fisher's exact tests between all columns (genes) of a 0/1 samples x genes matrix.
# Each test runs once and both statistics are read from it. Cell [r, c] holds the test of
# column c against column r. Returns a list of two genes x genes matrices:
#   interactions - -log10(p) when the odds ratio is > 1, log10(p) otherwise
#   oddsRatio    - odds ratio estimate reported by fisher.test
# Cells whose test fails (e.g. a gene with a single status across samples) stay NA; a warning is
# emitted for failing pairs of different genes.
.somaticInteractionsFisher = function(mutMat){
  geneNames = colnames(mutMat)
  nGenes = ncol(mutMat)
  logP = matrix(data = NA_real_, nrow = nGenes, ncol = nGenes, dimnames = list(geneNames, geneNames))
  oddsRatio = logP

  for(colIdx in seq_len(nGenes)){
    for(rowIdx in seq_len(nGenes)){
      f = try(fisher.test(mutMat[, colIdx], mutMat[, rowIdx]), silent = TRUE)
      if(inherits(f, "try-error")){
        gCol = geneNames[colIdx]
        gRow = geneNames[rowIdx]
        if(gCol != gRow){
          if(all(mutMat[, colIdx] == mutMat[, rowIdx])){
            warning("All the samples are in the same direction for the genes ", gCol, " and ",  gRow, "! Could not perform Fisher test.")
          }else{
            warning("Contigency table could not created for the genes ", gCol, " and ",  gRow, "! Could not perform Fisher test.")
          }
        }
        next
      }
      est = unname(f$estimate)
      oddsRatio[rowIdx, colIdx] = est
      if(!is.na(est)){
        logP[rowIdx, colIdx] = if(est > 1) -log10(f$p.value) else log10(f$p.value)
      }
    }
  }

  list(interactions = logP, oddsRatio = oddsRatio)
}

# Build one row per matrix cell listed in `cells` (two-column row/col index matrix, as returned by
# which(arr.ind = TRUE)). Columns: gene names, p-value, odds ratio, sample counts per status
# combination (`00`, `01`, `11`, `10`; first digit is gene1), FDR (pAdj), Event, an
# order-independent `pair` label and `event_ratio` ("both/exactly one").
.somaticInteractionsPairTable = function(mutMat, logPMat, oddsMat, cells){
  statusLevels = c("00", "01", "11", "10")
  pairTbl = data.table::rbindlist(lapply(X = seq_len(nrow(cells)), FUN = function(k){
    r = cells[k, 1]
    cl = cells[k, 2]
    g1 = rownames(logPMat)[r]
    g2 = colnames(logPMat)[cl]
    counts = table(factor(paste0(mutMat[, g1], mutMat[, g2]), levels = statusLevels))
    data.table::data.table(gene1 = g1, gene2 = g2,
                           pValue = 10^-abs(logPMat[r, cl]), oddsRatio = oddsMat[r, cl],
                           `00` = counts[["00"]], `01` = counts[["01"]],
                           `11` = counts[["11"]], `10` = counts[["10"]])
  }), fill = TRUE)

  pairTbl[, pAdj := p.adjust(pValue, method = 'fdr')]
  pairTbl[is.na(pairTbl)] = 0
  pairTbl$Event = ifelse(test = pairTbl$oddsRatio > 1, yes = "Co_Occurence", no = "Mutually_Exclusive")
  pairTbl$pair = vapply(seq_len(nrow(pairTbl)), function(k){
    paste(sort(unique(c(pairTbl$gene1[k], pairTbl$gene2[k]))), collapse = ", ")
  }, character(1))
  pairTbl[, event_ratio := paste0(`11`, '/', `01` + `10`)]
  pairTbl
}
